


#include "qelib.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <signal.h>
#include <getopt.h>

#if defined(CONFIG_X86_AVX)
#include <immintrin.h>
#endif


/* Log tag passed as the first argument to every qelog_* call made through the wrappers below. */
#define GEMMOP_LOGNAME    "gemm"
/* Per-severity logging wrappers: forward a printf-style argument list to the
 * matching qelog_* call, prefixed with GEMMOP_LOGNAME. */
#define gemm_debug(...)   qelog_debug(GEMMOP_LOGNAME,   __VA_ARGS__)
#define gemm_info(...)    qelog_info(GEMMOP_LOGNAME,    __VA_ARGS__)
#define gemm_notice(...)  qelog_notice(GEMMOP_LOGNAME,  __VA_ARGS__)
#define gemm_warning(...) qelog_warning(GEMMOP_LOGNAME, __VA_ARGS__)
#define gemm_error(...)   qelog_error(GEMMOP_LOGNAME,   __VA_ARGS__)


/* Default number of timed benchmark runs; main() uses it as the initial value of `epochs`. */
#define GEMM_EPOCHS     (1)
/* Set to 1 to compile in the per-multiply-accumulate debug logging of matrix_mult_0(). */
#define GEMM_MULT_DUMP  (0)


/*
 * Identifiers of the GEMM implementations. The values are used as indexes
 * into gemm_op_tb; GEMM_OP_SIZEOF is not an operation but the number of
 * slots in that table.
 */
typedef enum {
    GEMM_OP_0 = 0,
    GEMM_OP_1,
    GEMM_OP_2,
    GEMM_OP_3,
    GEMM_OP_4,
    GEMM_OP_5,
    GEMM_OP_6,
    GEMM_OP_7,
    GEMM_OP_SIZEOF,
} gemm_op_e;

/* Element type of every matrix in this file. */
typedef float mat_dtype;

/*
 * Dense matrix. The code in this file addresses element (i, j) as
 * blob[i * cols + j], i.e. rows are stored one after another.
 */
typedef struct {
    int rows;           /* number of rows */
    int cols;           /* number of columns */
    mat_dtype *blob;    /* rows * cols elements */
} matrix;

/* One entry of the implementation table (gemm_op_tb). */
typedef struct {
    char *name;                                   /* short identifier printed in the logs, e.g. "op0" */
    void (*impl)(matrix *, matrix *, matrix *);   /* kernel, called as impl(A, B, C) */
    char *desc;                                   /* human-readable description printed in the logs */
} gemm_handle;

/* Measurements recorded for a single benchmark epoch. */
typedef struct {
    qe_u32 ctime;   /* elapsed time of the epoch as passed to gemm_metric_update() (main logs it as ms) */
    qe_u32 flops;   /* zeroed by gemm_metric_init(); not written anywhere else in this file */
} gemm_epoch_metric;

/* Aggregated statistics over all epochs of one benchmark run. */
typedef struct {
    qe_u32 max_ctime;            /* largest ctime seen so far */
    qe_u32 min_ctime;            /* smallest ctime seen so far (starts at 0xFFFFFFFF) */
    qe_u32 max_flops;            /* initialized to 0; not updated in this file */
    qe_u32 min_flops;            /* initialized to 0xFFFFFFFF; not updated in this file */
    qe_u32 mean_ctime;           /* filled in by gemm_metric_static() */
    qe_u32 num_epochs;           /* number of entries in `epochs` */
    gemm_epoch_metric *epochs;   /* per-epoch records, allocated by gemm_metric_init() */
} gemm_metric;

/*
 * Release the per-epoch record array owned by `metric`.
 * The gemm_metric structure itself is not freed.
 */
static void gemm_metric_destory(gemm_metric *metric)
{
    gemm_epoch_metric *records = metric->epochs;

    qe_free(records);
}

/*
 * Compute the mean epoch time and store it in metric->mean_ctime.
 *
 * With no recorded epochs the mean is defined as 0 instead of dividing by
 * zero. The sum is accumulated in 64 bits so that many long epochs cannot
 * wrap a 32-bit counter before the division.
 */
static void gemm_metric_static(gemm_metric *metric)
{
    qe_u32 i;
    unsigned long long sum = 0;

    if (metric->num_epochs == 0 || !metric->epochs) {
        metric->mean_ctime = 0;
        return;
    }

    for (i=0; i<metric->num_epochs; i++) {
        sum += metric->epochs[i].ctime;
    }

    metric->mean_ctime = (qe_u32)(sum / metric->num_epochs);
}

/*
 * Record the elapsed time of one epoch and refresh the running min/max.
 *
 * metric: statistics object prepared by gemm_metric_init()
 * epoch:  index of the epoch, 0 <= epoch < metric->num_epochs
 * ctime:  elapsed time of that epoch
 *
 * Out-of-range epoch indexes are ignored so the record array is never
 * written out of bounds.
 */
static void gemm_metric_update(gemm_metric *metric, int epoch, qe_u32 ctime)
{
    if (epoch < 0 || (qe_u32)epoch >= metric->num_epochs) {
        gemm_warning("epoch %d out of range", epoch);
        return;
    }

    metric->epochs[epoch].ctime = ctime;

    /*
     * The two tests are independent on purpose: the first sample is both
     * the maximum and the minimum, so an else-if would leave min_ctime at
     * its 0xFFFFFFFF start value.
     */
    if (ctime > metric->max_ctime) {
        metric->max_ctime = ctime;
    }
    if (ctime < metric->min_ctime) {
        metric->min_ctime = ctime;
    }
}

/*
 * Prepare `metric` for a run of `epochs` epochs.
 *
 * All scalar fields are reset before the allocation is attempted, so the
 * structure is in a defined state (epochs == NULL, num_epochs == 0) even
 * when the allocation fails.
 *
 * Returns qe_ok on success, qe_err_mem when the per-epoch array cannot be
 * allocated.
 * NOTE(review): qe_ok is assumed to be qelib's success code - confirm
 * against qelib.h.
 */
static qe_ret gemm_metric_init(gemm_metric *metric, int epochs)
{
    metric->max_ctime  = 0;
    metric->min_ctime  = 0xFFFFFFFF;
    metric->max_flops  = 0;
    metric->min_flops  = 0xFFFFFFFF;
    metric->mean_ctime = 0;
    metric->num_epochs = 0;

    metric->epochs = qe_malloc(epochs * sizeof(gemm_epoch_metric));
    if (!metric->epochs) {
        gemm_error("alloc epochs error");
        return qe_err_mem;
    }

    qe_memset(metric->epochs, 0x0, epochs * sizeof(gemm_epoch_metric));
    metric->num_epochs = epochs;

    return qe_ok;
}

static matrix *matrix_create(int rows, int cols)
{
    matrix *m = qe_malloc(sizeof(matrix) + rows * cols * sizeof(mat_dtype));
    m->blob   = (mat_dtype *)((char *)m + sizeof(matrix));
    m->rows   = rows;
    m->cols   = cols;
}

/* Fill every element of `m` with zero bytes. */
static void matrix_set_zero(matrix *m)
{
    const int count = m->rows * m->cols;

    qe_memset(m->blob, 0x0, count * sizeof(mat_dtype));
}

/*
 * Fill `m` with pseudo-random whole numbers in [0, 9], stored as mat_dtype.
 * The C library generator is seeded from the wall clock on the first call
 * only; later calls continue the same random sequence.
 */
static void matrix_set_random(matrix *m)
{
    static qe_bool seeded = qe_false;
    const int span  = 10;
    const int floor = 0;
    int row, col;

    if (seeded == qe_false) {
        srand(time(NULL));
        seeded = qe_true;
    }

    for (row = 0; row < m->rows; row++) {
        const int base = row * m->cols;

        for (col = 0; col < m->cols; col++) {
            m->blob[base + col] = (mat_dtype)(rand() % span + floor);
        }
    }
}

/*
 * Print `m` to stdout, one bracketed row per line. Each element is
 * converted to qe_u32 and printed right-aligned in a 4-character field.
 */
static void matrix_dump(matrix *m)
{
    int row, col;

    for (row = 0; row < m->rows; row++) {
        const int base = row * m->cols;
        const int last = m->cols - 1;

        printf("[");
        for (col = 0; col < m->cols; col++) {
            qe_u32 value = (qe_u32)m->blob[base + col];

            if (col != last) {
                printf("%4d ", value);
            } else {
                /* the final column closes the bracket */
                printf("%4d]", value);
            }
        }
        printf("\n");
    }
}

#if defined(CONFIG_X86_AVX)
/**
 * @brief  Matrix Mult v7 (AVX, 1x8 block)
 * compute eight adjacent elements of one row of C at a time: broadcast
 * A[i,p] to all lanes, multiply by eight consecutive elements of row p of
 * B and accumulate in a 256-bit register.
 * columns left over when n is not a multiple of 8 are computed with a
 * scalar loop, so the vector loads/stores never run past the end of a row.
 *
 * @param[in]  a: Matrix A (m x k)
 * @param[in]  b: Matrix B (k x n)
 * @param[out] c: Matrix C (m x n), overwritten
 */
static void matrix_mult_7(matrix *a, matrix *b, matrix *c)
{
    enum { LANES = 8 };
    int m = a->rows;
    int n = b->cols;
    int k = a->cols;
    int n_full = n - n % LANES;     /* columns covered by whole vectors */

    for (int i=0; i<m; i++) {

        for (int j=0; j<n_full; j+=LANES) {

            __m256 vc = _mm256_setzero_ps();

            for (int p=0; p<k; p++) {
                __m256 va = _mm256_set1_ps(a->blob[i*k+p]);
                __m256 vb = _mm256_loadu_ps(b->blob + p*n + j);
                vc = _mm256_add_ps(vc, _mm256_mul_ps(va, vb));
            }
            _mm256_storeu_ps(c->blob + i*n + j, vc);
        }

        /* scalar tail for the last n % 8 columns */
        for (int j=n_full; j<n; j++) {
            mat_dtype temp = 0;
            for (int p=0; p<k; p++) {
                temp += a->blob[i*k+p] * b->blob[p*n+j];
            }
            c->blob[i*n+j] = temp;
        }
    }
}
#endif

/**
 * @brief  Matrix Mult v6 (8x8 block)
 * compute C in 8x8 tiles: for every p, load eight elements of column p of
 * A (rows i..i+7) and eight elements of row p of B (columns j..j+7) and
 * accumulate their outer product in a local tile, which is written to C
 * once per tile.
 * the tile loops have fixed trip counts so the compiler is free to unroll
 * them. rows/columns left over when m or n is not a multiple of 8 are
 * computed element by element afterwards.
 *
 * @param[in]  a: Matrix A (m x k)
 * @param[in]  b: Matrix B (k x n)
 * @param[out] c: Matrix C (m x n), overwritten
 */
static void matrix_mult_6(matrix *a, matrix *b, matrix *c)
{
    enum { TILE_M = 8, TILE_N = 8 };
    int i, j, p, r, s;
    int m = a->rows;
    int n = b->cols;
    int k = a->cols;
    int m_full = m - m % TILE_M;    /* rows covered by whole tiles */
    int n_full = n - n % TILE_N;    /* columns covered by whole tiles */

    for (i=0; i<m_full; i+=TILE_M) {

        for (j=0; j<n_full; j+=TILE_N) {

            /* accumulators must start at zero for every tile */
            mat_dtype temp_c[TILE_M][TILE_N] = {{0}};

            for (p=0; p<k; p++) {

                mat_dtype temp_a[TILE_M];
                mat_dtype temp_b[TILE_N];

                /* A[i+r, p]: step by whole rows (k elements), not by 1 */
                for (r=0; r<TILE_M; r++) {
                    temp_a[r] = a->blob[(i+r)*k+p];
                }
                for (s=0; s<TILE_N; s++) {
                    temp_b[s] = b->blob[p*n+j+s];
                }

                for (r=0; r<TILE_M; r++) {
                    for (s=0; s<TILE_N; s++) {
                        temp_c[r][s] += temp_a[r] * temp_b[s];
                    }
                }
            }

            for (r=0; r<TILE_M; r++) {
                for (s=0; s<TILE_N; s++) {
                    c->blob[(i+r)*n+j+s] = temp_c[r][s];
                }
            }
        }
    }

    /*
     * Remainder: the right-hand strip (columns >= n_full) of the tiled
     * rows and all columns of the rows below m_full.
     */
    for (i=0; i<m; i++) {
        for (j=(i<m_full) ? n_full : 0; j<n; j++) {
            mat_dtype temp = 0;
            for (p=0; p<k; p++) {
                temp += a->blob[i*k+p] * b->blob[p*n+j];
            }
            c->blob[i*n+j] = temp;
        }
    }
}

/**
 * @brief  Matrix Mult v5 (8x4 block)
 * compute C in tiles of 8 rows by 4 columns: for every p, load eight
 * elements of column p of A (rows i..i+7) and four elements of row p of B
 * (columns j..j+3) and accumulate their outer product in a local tile,
 * which is written to C once per tile.
 * the tile loops have fixed trip counts so the compiler is free to unroll
 * them. rows/columns left over when m is not a multiple of 8 or n is not
 * a multiple of 4 are computed element by element afterwards.
 *
 * @param[in]  a: Matrix A (m x k)
 * @param[in]  b: Matrix B (k x n)
 * @param[out] c: Matrix C (m x n), overwritten
 */
static void matrix_mult_5(matrix *a, matrix *b, matrix *c)
{
    enum { TILE_M = 8, TILE_N = 4 };
    int i, j, p, r, s;
    int m = a->rows;
    int n = b->cols;
    int k = a->cols;
    int m_full = m - m % TILE_M;    /* rows covered by whole tiles */
    int n_full = n - n % TILE_N;    /* columns covered by whole tiles */

    for (i=0; i<m_full; i+=TILE_M) {

        for (j=0; j<n_full; j+=TILE_N) {

            /* accumulators must start at zero for every tile */
            mat_dtype temp_c[TILE_M][TILE_N] = {{0}};

            for (p=0; p<k; p++) {

                mat_dtype temp_a[TILE_M];
                mat_dtype temp_b[TILE_N];

                /* A[i+r, p]: step by whole rows (k elements), not by 1 */
                for (r=0; r<TILE_M; r++) {
                    temp_a[r] = a->blob[(i+r)*k+p];
                }
                for (s=0; s<TILE_N; s++) {
                    temp_b[s] = b->blob[p*n+j+s];
                }

                for (r=0; r<TILE_M; r++) {
                    for (s=0; s<TILE_N; s++) {
                        temp_c[r][s] += temp_a[r] * temp_b[s];
                    }
                }
            }

            /* store all 8 rows x 4 columns of the tile */
            for (r=0; r<TILE_M; r++) {
                for (s=0; s<TILE_N; s++) {
                    c->blob[(i+r)*n+j+s] = temp_c[r][s];
                }
            }
        }
    }

    /*
     * Remainder: the right-hand strip (columns >= n_full) of the tiled
     * rows and all columns of the rows below m_full.
     */
    for (i=0; i<m; i++) {
        for (j=(i<m_full) ? n_full : 0; j<n; j++) {
            mat_dtype temp = 0;
            for (p=0; p<k; p++) {
                temp += a->blob[i*k+p] * b->blob[p*n+j];
            }
            c->blob[i*n+j] = temp;
        }
    }
}

/**
 * @brief  Matrix Mult v4 (4x8 block)
 * compute C in tiles of 4 rows by 8 columns: for every p, load four
 * elements of column p of A (rows i..i+3) and eight elements of row p of B
 * (columns j..j+7) and accumulate their outer product in a local tile,
 * which is written to C once per tile.
 * the tile loops have fixed trip counts so the compiler is free to unroll
 * them. rows/columns left over when m is not a multiple of 4 or n is not
 * a multiple of 8 are computed element by element afterwards.
 *
 * @param[in]  a: Matrix A (m x k)
 * @param[in]  b: Matrix B (k x n)
 * @param[out] c: Matrix C (m x n), overwritten
 */
static void matrix_mult_4(matrix *a, matrix *b, matrix *c)
{
    enum { TILE_M = 4, TILE_N = 8 };
    int i, j, p, r, s;
    int m = a->rows;
    int n = b->cols;
    int k = a->cols;
    int m_full = m - m % TILE_M;    /* rows covered by whole tiles */
    int n_full = n - n % TILE_N;    /* columns covered by whole tiles */

    for (i=0; i<m_full; i+=TILE_M) {

        for (j=0; j<n_full; j+=TILE_N) {

            /* accumulators must start at zero for every tile */
            mat_dtype temp_c[TILE_M][TILE_N] = {{0}};

            for (p=0; p<k; p++) {

                mat_dtype temp_a[TILE_M];
                mat_dtype temp_b[TILE_N];

                /* A[i+r, p]: step by whole rows (k elements), not by 1 */
                for (r=0; r<TILE_M; r++) {
                    temp_a[r] = a->blob[(i+r)*k+p];
                }
                for (s=0; s<TILE_N; s++) {
                    temp_b[s] = b->blob[p*n+j+s];
                }

                for (r=0; r<TILE_M; r++) {
                    for (s=0; s<TILE_N; s++) {
                        temp_c[r][s] += temp_a[r] * temp_b[s];
                    }
                }
            }

            for (r=0; r<TILE_M; r++) {
                for (s=0; s<TILE_N; s++) {
                    c->blob[(i+r)*n+j+s] = temp_c[r][s];
                }
            }
        }
    }

    /*
     * Remainder: the right-hand strip (columns >= n_full) of the tiled
     * rows and all columns of the rows below m_full.
     */
    for (i=0; i<m; i++) {
        for (j=(i<m_full) ? n_full : 0; j<n; j++) {
            mat_dtype temp = 0;
            for (p=0; p<k; p++) {
                temp += a->blob[i*k+p] * b->blob[p*n+j];
            }
            c->blob[i*n+j] = temp;
        }
    }
}

/**
 * @brief  Matrix Mult v3 (4x4 block)
 * compute C in 4x4 tiles: for every p, load four elements of column p of
 * A (rows i..i+3) and four elements of row p of B (columns j..j+3) and
 * accumulate their outer product in a local tile, which is written to C
 * once per tile.
 * the tile loops have fixed trip counts so the compiler is free to unroll
 * them. rows/columns left over when m or n is not a multiple of 4 are
 * computed element by element afterwards.
 *
 * @param[in]  a: Matrix A (m x k)
 * @param[in]  b: Matrix B (k x n)
 * @param[out] c: Matrix C (m x n), overwritten
 */
static void matrix_mult_3(matrix *a, matrix *b, matrix *c)
{
    enum { TILE_M = 4, TILE_N = 4 };
    int i, j, p, r, s;
    int m = a->rows;
    int n = b->cols;
    int k = a->cols;
    int m_full = m - m % TILE_M;    /* rows covered by whole tiles */
    int n_full = n - n % TILE_N;    /* columns covered by whole tiles */

    for (i=0; i<m_full; i+=TILE_M) {

        for (j=0; j<n_full; j+=TILE_N) {

            /* accumulators must start at zero for every tile */
            mat_dtype temp_c[TILE_M][TILE_N] = {{0}};

            for (p=0; p<k; p++) {

                mat_dtype temp_a[TILE_M];
                mat_dtype temp_b[TILE_N];

                /* A[i+r, p]: step by whole rows (k elements), not by 1 */
                for (r=0; r<TILE_M; r++) {
                    temp_a[r] = a->blob[(i+r)*k+p];
                }
                for (s=0; s<TILE_N; s++) {
                    temp_b[s] = b->blob[p*n+j+s];
                }

                for (r=0; r<TILE_M; r++) {
                    for (s=0; s<TILE_N; s++) {
                        temp_c[r][s] += temp_a[r] * temp_b[s];
                    }
                }
            }

            for (r=0; r<TILE_M; r++) {
                for (s=0; s<TILE_N; s++) {
                    c->blob[(i+r)*n+j+s] = temp_c[r][s];
                }
            }
        }
    }

    /*
     * Remainder: the right-hand strip (columns >= n_full) of the tiled
     * rows and all columns of the rows below m_full.
     */
    for (i=0; i<m; i++) {
        for (j=(i<m_full) ? n_full : 0; j<n; j++) {
            mat_dtype temp = 0;
            for (p=0; p<k; p++) {
                temp += a->blob[i*k+p] * b->blob[p*n+j];
            }
            c->blob[i*n+j] = temp;
        }
    }
}

/**
 * @brief  Matrix Mult v2 (1x4 block)
 * split the calculation into 1x4 small blocks: one row of matrix A is
 * combined with four adjacent columns of matrix B per pass, so each
 * A[i,p] is loaded once and used for four consecutive elements of row p
 * of B.
 * columns left over when n is not a multiple of 4 are computed one at a
 * time, so no element outside B or C is touched.
 *
 * @param[in]  a: Matrix A (m x k)
 * @param[in]  b: Matrix B (k x n)
 * @param[out] c: Matrix C (m x n), overwritten
 */
static void matrix_mult_2(matrix *a, matrix *b, matrix *c)
{
    enum { TILE_N = 4 };
    int i, j, p;
    int m = a->rows;
    int n = b->cols;
    int k = a->cols;
    int n_full = n - n % TILE_N;    /* columns covered by whole blocks */

    for (i=0; i<m; i++) {

        for (j=0; j<n_full; j+=TILE_N) {

            /* accumulators must start at zero for every block */
            mat_dtype temp[TILE_N] = {0};

            for (p=0; p<k; p++) {
                mat_dtype temp_a = a->blob[i*k+p];
                temp[0] += temp_a * b->blob[p*n+j+0];
                temp[1] += temp_a * b->blob[p*n+j+1];
                temp[2] += temp_a * b->blob[p*n+j+2];
                temp[3] += temp_a * b->blob[p*n+j+3];
            }

            c->blob[i*n+j+0] = temp[0];
            c->blob[i*n+j+1] = temp[1];
            c->blob[i*n+j+2] = temp[2];
            c->blob[i*n+j+3] = temp[3];
        }

        /* scalar tail for the last n % 4 columns */
        for (j=n_full; j<n; j++) {
            mat_dtype acc = 0;
            for (p=0; p<k; p++) {
                acc += a->blob[i*k+p] * b->blob[p*n+j];
            }
            c->blob[i*n+j] = acc;
        }
    }
}

/**
 * @brief  Matrix Mult v1
 * accumulate each dot product in a local variable and write it to C[i,j]
 * once, after the inner loop has finished, instead of updating C on every
 * step of the inner loop.
 *
 * @param[in]  a: Matrix A
 * @param[in]  b: Matrix B
 * @param[out] c: Matrix C
 */
static void matrix_mult_1(matrix *a, matrix *b, matrix *c)
{
    const int rows  = a->rows;
    const int cols  = b->cols;
    const int inner = a->cols;
    int row, col, idx;

    for (row = 0; row < rows; row++) {
        const int a_base = row * inner;     /* start of row `row` in A */
        const int c_base = row * cols;      /* start of row `row` in C */

        for (col = 0; col < cols; col++) {
            mat_dtype acc = 0;

            for (idx = 0; idx < inner; idx++) {
                acc += a->blob[a_base + idx] * b->blob[idx * cols + col];
            }
            c->blob[c_base + col] = acc;
        }
    }
}

/**
 * @brief  Matrix Mult v0
 * 3 layers loop foreach every row and col in a and b, accumulating
 * directly in C[i,j] on every step of the inner loop.
 * C[i,j] is reset before accumulation, so the result does not depend on
 * the previous content of C and repeated calls (several benchmark epochs)
 * give the same result as the other implementations.
 *
 * @param[in]  a: Matrix A
 * @param[in]  b: Matrix B
 * @param[out] c: Matrix C, overwritten
 */
static void matrix_mult_0(matrix *a, matrix *b, matrix *c)
{
    int i, j, p;
    int m = a->rows;
    int n = b->cols;
    int k = a->cols;

    for (i=0; i<m; i++) {
        for (j=0; j<n; j++) {
            c->blob[i*n+j] = 0;
            for (p=0; p<k; p++) {
#if (GEMM_MULT_DUMP == 1)
                gemm_debug("C[%d,%d]+=A[%d,%d]*B[%d,%d]", i, j, i, p, p, j);
#endif
                c->blob[i*n+j] += a->blob[i*k+p] * b->blob[p*n+j];
#if (GEMM_MULT_DUMP == 1)
                /* elements are floating point: print them with %g, not %d */
                gemm_debug("C[%d,%d]:%g A[%d,%d]:%g B[%d,%d]:%g", 
                    i, j, (double)c->blob[i*n+j], 
                    i, p, (double)a->blob[i*k+p], 
                    p, j, (double)b->blob[p*n+j]);
#endif
            }
        }
    }
}

/*
 * Implementation table, indexed by gemm_op_e.
 * The array always has GEMM_OP_SIZEOF slots. When CONFIG_X86_AVX is not
 * defined the "op7" initializer is compiled out, so the last slot is
 * zero-initialized (name, impl and desc are all null) and must not be
 * called.
 */
static gemm_handle gemm_op_tb[GEMM_OP_SIZEOF] = {
    {"op0", matrix_mult_0, "3 loop base"},
    {"op1", matrix_mult_1, "local temp value"},
    {"op2", matrix_mult_2, "1x4 block"},
    {"op3", matrix_mult_3, "4x4 block"},
    {"op4", matrix_mult_4, "4x8 block"},
    {"op5", matrix_mult_5, "8x4 block"},
    {"op6", matrix_mult_6, "8x8 block"},
#if defined(CONFIG_X86_AVX)
    {"op7", matrix_mult_7, "AVX 1x8 block"},
#endif
};

static void usage(void)
{
    printf("\n");
    printf("gemm_op <cmd> <opt>\n");
    printf("  -h,?,--help                print help information\n");
    printf("  -l,--log <level>           set log level(0~5)\n");
    printf("  -o,--op <op>               set gemm op(0~%d default 0)\n", GEMM_OP_SIZEOF);
    printf("  -m,--shape-m <m>           set shape m(default 512)\n");
    printf("  -n,--shape-n <n>           set shape n(default 512)\n");
    printf("  -k,--shape-k <k>           set shape k(default 512)\n");
    printf("  -d,--dump-matrix           dump matrix A B C\n");
    printf("  -s,--size <size>           set A B C with the same size\n");
    printf("  -v,--verbose               verbose in each epochs\n");
    printf("\n");
}

int main(int argc, char *argv[])
{
    int e;
    int m           = 320;
    int n           = 192;
    int k           = 256;
    int op          = GEMM_OP_0;
    int opt         = 0;
    int size        = 0;
    int epochs      = GEMM_EPOCHS;
    int verbose     = 0;
    int dump_matrix = 0;
    qe_u8 loglevel  = QELOG_INFO;
    qe_u32 t1, t2;
    matrix *mat_a, *mat_b, *mat_c;
    gemm_metric metric;

    qelog_init(loglevel, 
        QELOG_LV|QELOG_DM|QELOG_CL);

    static const struct option long_opts[] = {
        {"shape-m",     required_argument, QE_NULL, 'm'},
        {"shape-n",     required_argument, QE_NULL, 'n'},
        {"shape-k",     required_argument, QE_NULL, 'k'},
        {"op",          required_argument, QE_NULL, 'o'},
        {"log",         required_argument, QE_NULL, 'l'},
        {"epoch",       required_argument, QE_NULL, 'e'},
        {"dump-matrix", no_argument,       QE_NULL, 'd'},
        {"verbose",     no_argument,       QE_NULL, 'v'},
        {"size",        required_argument, QE_NULL, 's'},
        {"help",        no_argument,       QE_NULL, 'h'},
    };

    while ((opt = getopt_long(argc, argv, "o:l:m:n:k:e:s:dv?h-", long_opts, NULL)) != -1) {
        
        switch (opt) {

        case 'l':
            loglevel = atoi(optarg);
            qelog_set_level(loglevel);
            break;

        case 'o':
            op = atoi(optarg);
            break;

        case 'm':
            m = atoi(optarg);
            break;

        case 'n':
            n = atoi(optarg);
            break;

        case 'k':
            k = atoi(optarg);
            break;        

        case 'e':
            epochs= atoi(optarg);
            break;

        case 'd':
            dump_matrix = 1;
            break;

        case 'v':
            verbose = 1;
            break;

        case 's':
            size = atoi(optarg);
            break;

        case 'h':
        case '?':
        default:
            usage();
            exit(EXIT_SUCCESS);
        }
    }

    if (size) {
        m = size;
        n = size;
        k = size;
    }

    gemm_info("GEMM Optimize Benchmark");
    gemm_info("Matrix A(%d,%d) B(%d,%d) C(%d,%d) %d", m, k, k, n, m, n, sizeof(mat_dtype));
    gemm_info("%s %s", gemm_op_tb[op].name, gemm_op_tb[op].desc);

    mat_a = matrix_create(m, k);
    qe_assert(mat_a != QE_NULL);
    mat_b = matrix_create(k, n);
    qe_assert(mat_b != QE_NULL);
    mat_c = matrix_create(m, n);
    qe_assert(mat_c != QE_NULL);

    matrix_set_zero(mat_c);
    matrix_set_random(mat_a);
    matrix_set_random(mat_b);

    if (dump_matrix) {
        gemm_info("matrix A:");
        matrix_dump(mat_a);
        gemm_info("matrix B:");
        matrix_dump(mat_b);
    }

    gemm_metric_init(&metric, epochs);

    gemm_debug("benchmark begin...");

    for (e=0; e<epochs; e++) {

        t1 = qe_time_ms();
        gemm_op_tb[op].impl(mat_a, mat_b, mat_c);
        t2 = qe_time_ms();
        if (verbose) {
            gemm_info("epoch %03d ctime %dms", e, t2-t1);
        }
        gemm_metric_update(&metric, e, t2-t1);
    }
    
    gemm_debug("benchmark stop...");

    if (dump_matrix) {
        gemm_info("matrix C:");
        matrix_dump(mat_c);
    }

    if (epochs > 1) {
        gemm_metric_static(&metric);
        gemm_info("%s %s min:%dms max:%dms mean:%dms", 
            gemm_op_tb[op].name, gemm_op_tb[op].desc, 
            metric.min_ctime, metric.max_ctime, metric.mean_ctime);
    } else {
        gemm_info("%s %s ctime:%dms", gemm_op_tb[op].name, 
            gemm_op_tb[op].desc, t2-t1);
    }

    gemm_info("GEMM Optimize finish");

    qe_free(mat_a);
    qe_free(mat_b);
    qe_free(mat_c);
    gemm_metric_destory(&metric);

    return 0;
}